package com.je.gateway.filter;

import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.je.gateway.util.IpUtil;
import com.je.gateway.util.SecurityUserHolder;
import com.je.gateway.util.SpringContextHolder;
import org.apache.servicecomb.core.Invocation;
import org.apache.servicecomb.foundation.common.http.HttpStatus;
import org.apache.servicecomb.foundation.vertx.http.HttpServletRequestEx;
import org.apache.servicecomb.swagger.invocation.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.env.Environment;
import org.springframework.data.redis.core.StringRedisTemplate;

import java.util.ArrayList;
import java.util.List;

/**
 * Gateway filter that applies the IP restriction policy to incoming requests.
 *
 * <p>The policy is read from the Redis string key {@code ipRestriction} as a JSON object with a
 * {@code switch} flag and a {@code rules} array. When the switch is {@code "1"} and rules exist,
 * the client IP must match either an {@code ADDRESS} rule (exact match) or a {@code RANGE} rule
 * (CIDR match); otherwise the request is answered with HTTP 403.
 */
public class GatewayIpRestrictionFilter extends AbstractHttpServerFilter {

    private static final Logger logger = LoggerFactory.getLogger(GatewayIpRestrictionFilter.class);

    /**
     * Splits comma-separated rule content. Whitespace around entries is trimmed and empty entries
     * are dropped, so values such as {@code "1.1.1.1, 2.2.2.2,"} are handled as two addresses.
     */
    private static final Splitter RULE_SPLITTER = Splitter.on(',').trimResults().omitEmptyStrings();

    @Override
    public int getOrder() {
        return 2;
    }

    /**
     * Checks the client IP of the request against the cached IP restriction rules.
     *
     * @param invocation the current invocation (not used by this filter)
     * @param requestEx  the incoming HTTP request
     * @return {@code null} to let the request continue, or a 403 response when the client IP
     *         matches none of the applicable rules
     */
    @Override
    public Response afterReceiveRequest(Invocation invocation, HttpServletRequestEx requestEx) {
        // Without a current account id the account-based configuration does not apply.
        if (Strings.isNullOrEmpty(SecurityUserHolder.getCurrentAccountId())) {
            return null;
        }

        // NOTE(review): this bypass is driven purely by a request header; confirm that an upstream
        // component sets/strips "isOpen" so that clients cannot supply it themselves.
        String isOpen = requestEx.getHeader("isOpen");
        if ("1".equals(isOpen)) {
            return null;
        }

        StringRedisTemplate stringRedisTemplate = SpringContextHolder.getBean(StringRedisTemplate.class);
        String value = stringRedisTemplate.opsForValue().get("ipRestriction");
        if (value == null) {
            return null;
        }

        // parseObject yields null for an empty (or literal "null") payload; treat that the same
        // way as a missing configuration instead of failing with a NullPointerException.
        JSONObject valueObj = JSON.parseObject(value);
        if (valueObj == null || !"1".equals(valueObj.getString("switch"))) {
            return null;
        }

        logger.info("Check IP restrict begin! checked the IP restrict rules cache!");
        JSONArray rulesArray = valueObj.getJSONArray("rules");
        // No rules configured: let the request pass.
        if (rulesArray == null || rulesArray.isEmpty()) {
            logger.info("The IP restrict rules is empty,now let this request continue!");
            return null;
        }

        String clientIp = IpUtil.getIpAddress(requestEx);
        logger.info("Check the client {} for IP restrict!", clientIp);
        if (validAccountIpAddress(clientIp, rulesArray) || validAccountIpCidr(clientIp, rulesArray)) {
            logger.info("Check the client {},account {} for IP restrict success,now let this request continue!", clientIp, SecurityUserHolder.getCurrentAccountName());
            return null;
        }
        logger.info("Check the client {}, account {} for IP restrict failure,now forbidden this request accessing!", clientIp, SecurityUserHolder.getCurrentAccountName());
        return Response.create(new HttpStatus(403, ""), "Your client IP is restricted from accessing");
    }

    /**
     * Checks whether the client IP is contained in the list of allowed addresses.
     *
     * <p>The list consists of the applicable {@code ADDRESS} rules, the loopback entries
     * {@code 127.0.0.1} and {@code localhost}, and the comma-separated addresses of the
     * {@code restrict.defaultAllowClientIps} property, if set.
     *
     * @param clientIp   the client IP
     * @param rulesArray the configured rules
     * @return {@code true} if the client IP equals one of the allowed addresses
     */
    private boolean validAccountIpAddress(String clientIp, JSONArray rulesArray) {
        List<String> allowedAddresses = filterAccountRuleContent(rulesArray, "ADDRESS");
        // Loopback entries are always allowed.
        allowedAddresses.add("127.0.0.1");
        allowedAddresses.add("localhost");
        Environment environment = SpringContextHolder.getBean(Environment.class);
        String defaultAllowIps = environment.getProperty("restrict.defaultAllowClientIps");
        if (!Strings.isNullOrEmpty(defaultAllowIps)) {
            allowedAddresses.addAll(RULE_SPLITTER.splitToList(defaultAllowIps));
        }
        return allowedAddresses.contains(clientIp);
    }

    /**
     * Checks whether the client IP falls into one of the allowed CIDR ranges.
     *
     * <p>The ranges consist of the applicable {@code RANGE} rules and the comma-separated entries
     * of the {@code restrict.defaultAllowClientCidrs} property, if set.
     *
     * @param clientIp   the client IP
     * @param rulesArray the configured rules
     * @return {@code true} if {@code IpUtil.isInRange} accepts the client IP for any range
     */
    private boolean validAccountIpCidr(String clientIp, JSONArray rulesArray) {
        List<String> allowedCidrs = filterAccountRuleContent(rulesArray, "RANGE");
        Environment environment = SpringContextHolder.getBean(Environment.class);
        String defaultAllowCidrs = environment.getProperty("restrict.defaultAllowClientCidrs");
        if (!Strings.isNullOrEmpty(defaultAllowCidrs)) {
            allowedCidrs.addAll(RULE_SPLITTER.splitToList(defaultAllowCidrs));
        }
        for (String eachCidr : allowedCidrs) {
            if (IpUtil.isInRange(clientIp, eachCidr)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects the rule contents of the given type that apply to the current account.
     *
     * <p>A rule applies when its {@code accountName} is {@code "*"} or when its {@code accountId}
     * equals the current account id. The {@code ruleContent} of each applicable rule is split on
     * commas and every entry is added to the result.
     *
     * @param rulesArray the configured rules
     * @param ruleType   the rule type to select, compared against each rule's {@code ruleType}
     * @return a new mutable list with the matching rule entries; empty if nothing applies
     */
    private List<String> filterAccountRuleContent(JSONArray rulesArray, String ruleType) {
        List<String> accountRestrictRuleList = new ArrayList<>();
        String currentAccountId = SecurityUserHolder.getCurrentAccountId();
        for (int i = 0; i < rulesArray.size(); i++) {
            JSONObject rule = rulesArray.getJSONObject(i);
            // Skip null entries and rules of a different type.
            if (rule == null || !ruleType.equals(rule.getString("ruleType"))) {
                continue;
            }
            // Skip rules without content.
            String ruleContent = rule.getString("ruleContent");
            if (Strings.isNullOrEmpty(ruleContent)) {
                continue;
            }

            // "*" as account name matches every account; otherwise the account id must match.
            boolean appliesToAllAccounts = "*".equals(rule.getString("accountName"));
            boolean appliesToCurrentAccount = currentAccountId != null
                    && currentAccountId.equals(rule.getString("accountId"));
            if (appliesToAllAccounts || appliesToCurrentAccount) {
                accountRestrictRuleList.addAll(RULE_SPLITTER.splitToList(ruleContent));
            }
        }
        return accountRestrictRuleList;
    }

}
